import io
import sys
import pickle
import pathlib
import zipfile
import warnings
import operator
from enum import Enum
from functools import reduce, wraps
from typing import Dict, Union, Optional, cast
 
import torch
import numpy as np
import mindspore
from ml_dtypes import bfloat16
 
def _is_path(name_or_buffer):
    return isinstance(name_or_buffer, (str, pathlib.Path))
 
class _opener:
    def __init__(self, file_like):
        self.file_like = file_like
 
    def __enter__(self):
        return self.file_like
 
    def __exit__(self, *args):
        pass
 
class _open_file(_opener):
    """Opens ``name`` with the builtin ``open`` and closes the handle when the ``with`` block ends."""

    def __init__(self, name, mode):
        handle = open(name, mode)
        super().__init__(handle)

    def __exit__(self, *args):
        # This class opened the file, so it is responsible for closing it.
        self.file_like.close()
 
class _open_buffer_writer(_opener):
    """Wraps a caller-owned writable buffer: flushes it on exit but leaves it open."""

    def __exit__(self, *args):
        # The caller owns the buffer, so only flush; never close it here.
        self.file_like.flush()
 
class _open_buffer_reader(_opener):
    """Wraps a caller-owned readable buffer after verifying that it supports ``seek``/``tell``."""

    def __init__(self, buffer):
        super().__init__(buffer)
        # Fail fast with a descriptive error for streams that cannot seek.
        _check_seekable(buffer)
 
def _check_seekable(f) -> bool:
 
    def raise_err_msg(patterns, e):
        for p in patterns:
            if p in str(e):
                msg = (str(e) + ". You can only torch.load from a file that is seekable."
                                + " Please pre-load the data into a buffer like io.BytesIO and"
                                + " try to load from it instead.")
                raise type(e)(msg)
        raise e
 
    try:
        f.seek(f.tell())
        return True
    except (io.UnsupportedOperation, AttributeError) as e:
        raise_err_msg(["seek", "tell"], e)
    return False
 
 
def _open_file_like(name_or_buffer, mode):
    """
    Return a context manager yielding a file object for ``name_or_buffer``.

    Paths are opened (and later closed) with ``mode``. Buffers are wrapped as-is:
    a writer if ``mode`` contains 'w', otherwise a reader if it contains 'r'.

    Raises:
        RuntimeError: If a buffer is given and ``mode`` contains neither 'r' nor 'w'.
    """
    if _is_path(name_or_buffer):
        return _open_file(name_or_buffer, mode)
    if 'w' in mode:
        return _open_buffer_writer(name_or_buffer)
    if 'r' in mode:
        return _open_buffer_reader(name_or_buffer)
    raise RuntimeError(f"Expected 'r' or 'w' in mode but got {mode}")
 
def _is_zipfile(f) -> bool:
    """
    Args:
        f (file object): The file object to be checked for being a valid zip file.
            It should be opened in binary mode and point to the beginning of the file.
 
    Returns:
        bool: Returns True if the input file is a valid zip file, otherwise False.
 
    Raises:
        No specific exceptions are raised by this function.
    """
    # This is a stricter implementation than zipfile.is_zipfile().
    # zipfile.is_zipfile() is True if the magic number appears anywhere in the
    # binary. Since we expect the files here to be generated by torch.save or
    # torch.jit.save, it's safe to only check the start bytes and avoid
    # collisions and assume the zip has only 1 file.
    # See bugs.python.org/issue28494.
 
    # Read the first 4 bytes of the file
    read_bytes = []
    start = f.tell()
 
    byte = f.read(1)
    while byte != b"":
        read_bytes.append(byte)
        if len(read_bytes) == 4:
            break
        byte = f.read(1)
    f.seek(start)
 
    local_header_magic_number = [b'P', b'K', b'\x03', b'\x04']
    return read_bytes == local_header_magic_number
 
 
class PyTorchFileReader:
    """
    Minimal reader for the zip archives produced by ``torch.save``.

    Mirrors the subset of PyTorch's internal ``PyTorchFileReader`` interface that this
    module needs. All records are expected to live under one top-level directory
    (taken from the first archive entry); record names passed to the methods below
    are relative to that directory.

    N.B.: ScriptObjects are not depickleable or accessible via this class.
    """
    def __init__(self, file):
        """
        Open ``file`` as a zip archive and remember its top-level directory.

        Args:
            file (str | path-like | binary file object): The archive; anything accepted by
                ``zipfile.ZipFile``. If the object has an ``offset`` attribute, it is rewound
                and its first ``file.len`` bytes are copied into memory and read from there.

        Raises:
            OSError: If ``file`` is a path that cannot be opened.
            zipfile.BadZipFile: If ``file`` is not a valid zip archive.
            IndexError: If the archive contains no entries.
        """
        self.file = zipfile.ZipFile(file)
        if hasattr(file, 'offset'):
            # NOTE(review): objects exposing `offset`/`len` are re-read into memory;
            # presumably wrappers around an archive embedded in a larger file — confirm.
            file.seek(0)
            payload = io.BytesIO(file.read(file.len))
            self.file = zipfile.ZipFile(payload)

        # Every record sits under a single top-level directory; derive it from the first entry.
        self.directory = self.file.namelist()[0].split('/')[0]

    def _getinfo(self, name):
        """Return the ZipInfo of record ``name``, or None if the archive has no such record."""
        # getinfo is a dict lookup; `in namelist()` would rebuild and scan a list per call.
        try:
            return self.file.getinfo(f"{self.directory}/{name}")
        except KeyError:
            return None

    def open_record(self, name):
        """
        Open record ``name`` for reading.

        Args:
            name (str): Record name relative to the archive directory.

        Returns:
            zipfile.ZipExtFile | None: An open file object, or None if the record does not exist.
        """
        info = self._getinfo(name)
        return None if info is None else self.file.open(info)

    def read_record(self, name):
        """
        Read the full contents of record ``name``.

        Args:
            name (str): Record name relative to the archive directory.

        Returns:
            bytes | None: The record's bytes, or None if the record does not exist.
        """
        info = self._getinfo(name)
        return None if info is None else self.file.read(info)

    def has_record(self, name):
        """
        Check whether record ``name`` exists.

        Args:
            name (str): Record name relative to the archive directory.

        Returns:
            bool: True if the record is present.
        """
        return self._getinfo(name) is not None

    def get_all_records(
        self,
    ):
        """
        List every entry in the archive.

        Returns:
            list[str]: Entry names with the leading ``<directory>/`` prefix removed. Entries
            outside that directory are returned unchanged.
        """
        prefix = self.directory + '/'
        # Strip only the leading prefix; str.replace would also delete later occurrences.
        return [name[len(prefix):] if name.startswith(prefix) else name
                for name in self.file.namelist()]

    def get_record_offset(self, name):
        """
        Return the offset of record ``name``'s local file header within the archive.

        Args:
            name (str): Record name relative to the archive directory.

        Returns:
            int | None: ``ZipInfo.header_offset`` of the record, or None if it does not exist.
        """
        info = self._getinfo(name)
        return None if info is None else info.header_offset
 
 
class _open_zipfile_reader(_opener):
    """
    Context manager that yields a ``PyTorchFileReader`` over ``name_or_buffer``.

    Exiting the ``with`` block runs the no-op ``__exit__`` inherited from ``_opener``,
    so the underlying zip archive is not closed here.
    """

    def __init__(self, name_or_buffer) -> None:
        """
        Args:
            name_or_buffer (str | file-like): Path or binary file object handed straight to
                ``PyTorchFileReader``.

        Raises:
            Whatever ``PyTorchFileReader`` raises for unreadable or invalid archives
            (for example ``zipfile.BadZipFile``).
        """
        reader = PyTorchFileReader(name_or_buffer)
        super().__init__(reader)
 
def _is_torchscript_zip(zip_file):
    """
    Checks if the given zip file contains a specific record.
 
    Args:
        zip_file (object): The zip file to be checked for the presence of a specific record.
 
    Returns:
        None: This function does not return any value.
 
    Raises:
        None
    """
    return 'constants.pkl' in zip_file.get_all_records()
 
 
class LoadEndianness(Enum):
    """
    Fallback byte order for checkpoints that carry no ``byteorder`` record.

    This is the type returned by ``get_default_load_endianness()``.

    Members:
        NATIVE: Use the stored bytes as-is, without assuming a byte order.
        LITTLE: Assume the stored data is little-endian.
        BIG: Assume the stored data is big-endian.
    """
    NATIVE = 1
    LITTLE = 2
    BIG = 3
 
# Fallback byte order for checkpoints without a byteorder record; None selects the built-in default.
_default_load_endian: Optional[LoadEndianness] = None

def get_default_load_endianness() -> Optional[LoadEndianness]:
    '''
    Get the fallback byte order for loading files.

    The value is consulted only when a saved checkpoint has no ``byteorder`` record.
    It defaults to ``None``, which ``_load`` treats like ``LoadEndianness.LITTLE``
    and additionally warns about on big-endian machines. Nothing in this module
    changes the value.

    Returns:
        Optional[LoadEndianness]: The configured fallback, or ``None`` if unset.
    '''
    return _default_load_endian
 
def _maybe_decode_ascii(bytes_str: Union[bytes, str]) -> str:
    """
    This function decodes a bytes string to ASCII if it is a bytes type, otherwise returns the input string.
 
    Args:
        bytes_str (Union[bytes, str]): A bytes or string input to be decoded if it is a bytes type. If it is already a string, it will be returned as is.
 
    Returns:
        str: The decoded ASCII string if the input is of bytes type, otherwise the original string.
 
    Raises:
        None
    """
    # When using encoding='bytes' in Py3, some **internal** keys stored as
    # strings in Py2 are loaded as bytes. This function decodes them with
    # ascii encoding, one that Py3 uses by default.
    #
    # NOTE: This should only be used on internal keys (e.g., `typename` and
    #       `location` in `persistent_load` below!
    if isinstance(bytes_str, bytes):
        return bytes_str.decode('ascii')
    return bytes_str
 
# Torch storage class names, as they appear in a checkpoint's pickle (the unpickler
# in `_load` resolves `torch.<Name>` globals to the plain string "<Name>"), mapped to
# the numpy dtype used to view that storage's raw bytes.
dtype_map = {
    "DoubleStorage": np.float64,
    "FloatStorage": np.float32,
    "HalfStorage": np.float16,
    'BFloat16Storage': bfloat16,
    'LongStorage': np.int64,
    "IntStorage": np.int32,
    "ShortStorage": np.int16,
    "CharStorage": np.int8,
    'ByteStorage': np.uint8,
    'BoolStorage': np.bool_
}
 
 
def load_ms_weights(f, map_location=None, pickle_module=pickle, *, weights_only=False, mmap=None, **pickle_load_args):
    """
    Load a checkpoint written in the ``torch.save`` zip format without calling ``torch.load``.

    Args:
        f (str | pathlib.Path | binary file object): Checkpoint path, or a seekable buffer
            opened for reading.
        map_location: Unused; accepted for signature compatibility with ``torch.load``.
        pickle_module (module, optional): Module whose ``Unpickler`` is subclassed;
            ``None`` falls back to the standard ``pickle`` module.
        weights_only (bool): Unused; accepted for signature compatibility. NOTE: no
            restricted unpickler is applied, so this does NOT make loading safe.
        mmap: Unused; accepted for signature compatibility.
        **pickle_load_args: Extra keyword arguments for the Unpickler; ``encoding``
            defaults to 'utf-8'.

    Returns:
        The deserialized object produced by ``_load``.

    Raises:
        ValueError: If ``f`` is not a zip-format checkpoint (the legacy non-zip
            ``torch.save`` format is not supported), or if it is a TorchScript archive.

    Security:
        Unpickling can execute arbitrary code; only load checkpoints from trusted sources.
    """
    if pickle_module is None:
        pickle_module = pickle

    if 'encoding' not in pickle_load_args:
        pickle_load_args['encoding'] = 'utf-8'

    with _open_file_like(f, 'rb') as opened_file:
        if not _is_zipfile(opened_file):
            # Previously this fell through and silently returned None.
            raise ValueError(
                "load_ms_weights only supports checkpoints saved in the zip format "
                "(torch.save with _use_new_zipfile_serialization=True); "
                "the legacy format is not supported")
        with _open_zipfile_reader(opened_file) as opened_zipfile:
            if _is_torchscript_zip(opened_zipfile):
                raise ValueError('do not support torchscript now')
            # Storages are always read into memory here; memory-mapping is not wired up.
            return _load(opened_zipfile,
                         map_location,
                         pickle_module,
                         overall_storage=None,
                         **pickle_load_args)
 
 
# Rebuild helpers that a checkpoint's pickle may reference from torch._utils /
# torch._tensor. Only the names are listed: the functions are defined further down
# in this module, so a dict literal referencing them here raised NameError at import
# time. They are resolved from the module namespace on lookup instead.
_REBUILD_FUNC_NAMES = frozenset({'_rebuild_from_type_v2', '_rebuild_tensor_v2'})


def get_func_by_name(name: str):
    """
    Return the module-level rebuild function called ``name``.

    Args:
        name (str): Function name referenced by the checkpoint's pickle.

    Returns:
        The matching function from this module.

    Raises:
        RuntimeError: If ``name`` is not a whitelisted rebuild function, or is not defined.
    """
    func = globals().get(name) if name in _REBUILD_FUNC_NAMES else None
    if func is None:
        raise RuntimeError(f"load checkpoints failed, function name '{name}' is invalid.")
    return func


def _load(zip_file, map_location, pickle_module, pickle_file='data.pkl', overall_storage=None, **pickle_load_args):
    """
    Unpickle ``pickle_file`` from a torch.save zip archive, materialising storages as numpy arrays.

    Args:
        zip_file (PyTorchFileReader): Reader over the checkpoint archive.
        map_location: Unused; accepted for signature compatibility with ``torch.load``.
        pickle_module (module): Module whose ``Unpickler`` class is subclassed for deserialization.
        pickle_file (str, optional): Record holding the pickled object. Default is 'data.pkl'.
        overall_storage (optional): When not None, it is passed as the ``filename`` argument of
            ``np.memmap`` and each storage is mapped at its record's data offset. Presumably this
            must be the checkpoint archive file itself, stored uncompressed — confirm with callers.
        **pickle_load_args: Extra keyword arguments forwarded to the Unpickler constructor.

    Returns:
        The unpickled object, passed through ``transform_ms_dtype_to_pt_dtype``.

    Raises:
        ValueError: If the byteorder record holds an unknown value, if the configured default
            load endianness is invalid, or if a persistent id's typename is not 'storage'.
        TypeError: If a persistent id is not a tuple.
        RuntimeError: If ``pickle_file`` is missing, or the pickle references an unsupported
            torch rebuild function.
        KeyError: If a storage type has no entry in ``dtype_map``.

    Warns:
        UserWarning: On big-endian machines, when the archive has no byteorder record and no
            default load endianness is configured.

    Security:
        Globals outside ``torch``/``torch._utils``/``torch._tensor`` are resolved by the regular
        ``find_class``, so a malicious pickle can execute arbitrary code. Only load trusted files.
    """
    loaded_storages = {}

    # Work out which byte order the storages were written in.
    byteordername = 'byteorder'
    byteorderdata = None
    default_endianness = get_default_load_endianness()
    if zip_file.has_record(byteordername):
        byteorderdata = zip_file.read_record(byteordername)
        if byteorderdata not in [b'little', b'big']:
            raise ValueError('Unknown endianness type: ' + byteorderdata.decode())
    elif default_endianness == LoadEndianness.LITTLE or default_endianness is None:
        byteorderdata = b'little'
    elif default_endianness == LoadEndianness.BIG:
        byteorderdata = b'big'
    elif default_endianness == LoadEndianness.NATIVE:
        pass
    else:
        raise ValueError('Invalid load endianness type')

    if not zip_file.has_record(byteordername) and \
            default_endianness is None and \
            sys.byteorder == 'big':
        # Default behaviour was changed
        # See https://github.com/pytorch/pytorch/issues/101688
        warnings.warn("The default load endianness for checkpoints without a byteorder mark "
                      "on big endian machines was changed from 'native' to 'little' endian, "
                      "to avoid this behavior please use "
                      "torch.serialization.set_default_load_endianness to set "
                      "the desired default load endianness",
                      UserWarning)

    # Storages hold raw bytes in the archive's byte order; swap when it differs from this
    # machine's. byteorderdata stays None only for NATIVE, meaning "use the bytes as stored".
    swap_bytes = byteorderdata is not None and byteorderdata.decode() != sys.byteorder

    def persistent_load(saved_id):
        """Resolve a ('storage', storage_type, key, location, numel) id to a cached 1-D numpy array."""
        if not isinstance(saved_id, tuple):
            raise TypeError(f"saved_id must be a tuple, get {type(saved_id).__name__}")
        typename = _maybe_decode_ascii(saved_id[0])
        data = saved_id[1:]

        if typename != 'storage':
            raise ValueError(f"Unknown typename for persistent_load, expected 'storage' but got '{typename}'")
        storage_type, key, location, numel = data

        name = f'data/{key}'
        # Several tensors can share one storage; return the same array for each.
        if name in loaded_storages:
            return loaded_storages[name]

        if overall_storage is not None:
            # Opening the record positions its underlying file object at the data start.
            with zip_file.open_record(name) as record:
                data_offset = record._fileobj.tell()
            array = np.memmap(overall_storage, dtype=dtype_map[storage_type],
                              offset=data_offset, shape=(numel,))
        else:
            array = np.frombuffer(zip_file.read_record(name), dtype_map[storage_type])
        if swap_bytes and array.dtype.itemsize > 1:
            # byteswap() returns a swapped copy; frombuffer arrays are read-only.
            array = array.byteswap()
        loaded_storages[name] = array
        return array

    load_module_mapping: Dict[str, str] = {
        # See https://github.com/pytorch/pytorch/pull/51633
        'torch.tensor': 'torch._tensor'
    }

    # Need to subclass Unpickler instead of directly monkey-patching the find_class method
    # because it's marked readonly in pickle.
    # The type: ignore is because mypy can't statically determine the type of this class.
    class UnpicklerWrapper(pickle_module.Unpickler):  # type: ignore[name-defined]
        # from https://stackoverflow.com/questions/13398462/unpickling-python-objects-with-a-changed-module-path/13405732
        # Lets us override the imports that pickle uses when unpickling an object.
        # This is useful for maintaining BC if we change a module path that tensor instantiation relies on.
        def find_class(self, mod_name, name):
            # Tensor rebuild helpers are replaced by this module's numpy-based versions.
            if mod_name == 'torch._utils':
                return get_func_by_name(name)
            # Other torch globals (e.g. storage class names) become plain strings,
            # which persistent_load uses as dtype_map keys.
            if mod_name == 'torch':
                return str(name)
            if mod_name == 'torch._tensor':
                return get_func_by_name(name)
            mod_name = load_module_mapping.get(mod_name, mod_name)
            return super().find_class(mod_name, name)

    # Load the data (which may in turn use `persistent_load` to load tensors)
    data_file = zip_file.open_record(pickle_file)
    if data_file is None:
        raise RuntimeError(f"checkpoint has no '{pickle_file}' record")
    with data_file:
        unpickler = UnpicklerWrapper(data_file, **pickle_load_args)
        unpickler.persistent_load = persistent_load
        result = unpickler.load()

    return transform_ms_dtype_to_pt_dtype(result)
 
# MindSpore dtype -> torch dtype, used to translate dtype-pair dict keys in loaded state.
DTYPE_MAP = {
    mindspore.float16: torch.float16,
    mindspore.float32: torch.float32,
    mindspore.float64: torch.float64,
    mindspore.bfloat16: torch.bfloat16
}
 
def transform_ms_dtype_to_pt_dtype(state):
    """
    Recursively copy ``state``, translating MindSpore dtype-pair dict keys into torch dtypes.

    Dicts and lists are rebuilt, with their values/members converted recursively. Any
    other object (including tuples used as values) is returned unchanged. A dict key
    that is a 2-tuple is treated as a pair of MindSpore dtypes, and each element is
    mapped through ``DTYPE_MAP``.

    Raises:
        ValueError: If an element of such a 2-tuple key has no entry in ``DTYPE_MAP``.
    """
    def to_torch(ms_dtype):
        pt_dtype = DTYPE_MAP.get(ms_dtype)
        if pt_dtype is None:
            raise ValueError(f"convert error, unsupported dtype {ms_dtype}")
        return pt_dtype

    if isinstance(state, dict):
        converted = {}
        for key, value in state.items():
            value = transform_ms_dtype_to_pt_dtype(value)
            if isinstance(key, tuple) and len(key) == 2:
                key = tuple(to_torch(ms_dtype) for ms_dtype in key)
            converted[key] = value
        return converted
    if isinstance(state, list):
        return [transform_ms_dtype_to_pt_dtype(member) for member in state]
    return state
 
 
def _contiguous_strides(size, fortran=False):
    """Return the element strides of a densely packed array of shape ``size`` (C order, or Fortran order)."""
    strides = [0] * len(size)
    step = 1
    # C order: the last axis varies fastest; Fortran order: the first axis does.
    axes = range(len(size)) if fortran else reversed(range(len(size)))
    for axis in axes:
        strides[axis] = step
        step *= size[axis]
    return strides


def _strides_match(size, stride, expected):
    """Return True if ``stride`` equals ``expected`` on every axis longer than 1 (size-1 axes never step)."""
    return all(dim == 1 or got == want for dim, got, want in zip(size, stride, expected))


def _rebuild_tensor_v2(storage, storage_offset, size, stride, requires_grad, backward_hooks, metadata=None):
    '''Rebuild a torch tensor from a flat storage array (stand-in for torch._utils._rebuild_tensor_v2).

    Args:
        storage (ndarray): 1-D array holding the raw storage elements.
        storage_offset (int): Index, in elements, of the tensor's first element within ``storage``.
        size (tuple): Shape of the tensor.
        stride (tuple or None): Per-axis stride in elements, or None for a C-contiguous layout.
        requires_grad (bool): Unused; kept for signature compatibility with the pickled call.
        backward_hooks: Unused; kept for signature compatibility.
        metadata (Any, optional): Unused.

    Returns:
        torch.Tensor: Tensor of shape ``size`` whose elements follow ``stride`` within ``storage``.
        C- and Fortran-contiguous layouts share memory with ``storage``. Other layouts (for
        example permuted views) are gathered into a copy. bfloat16 data is always copied.

    Raises:
        ValueError / IndexError: If ``size``/``stride``/``storage_offset`` address elements
            beyond the end of ``storage``.
    '''
    size = tuple(size)
    num_elements = reduce(operator.mul, size, 1)

    if stride is None or num_elements == 0 or _strides_match(size, stride, _contiguous_strides(size)):
        array = storage[storage_offset: storage_offset + num_elements].reshape(size, order="C")
    elif _strides_match(size, stride, _contiguous_strides(size, fortran=True)):
        array = storage[storage_offset: storage_offset + num_elements].reshape(size, order="F")
    else:
        # General strided layout: element at index (i0, i1, ...) lives at
        # storage_offset + sum(i_d * stride_d). Build that index grid and gather (copies).
        index = np.full(size, storage_offset, dtype=np.intp)
        for axis, (dim, step) in enumerate(zip(size, stride)):
            shape = [1] * len(size)
            shape[axis] = dim
            index = index + (np.arange(dim, dtype=np.intp) * step).reshape(shape)
        array = storage[index]

    if array.dtype == bfloat16:
        if array.size == 0:
            # torch.frombuffer rejects empty buffers.
            return torch.empty(array.shape, dtype=torch.bfloat16)
        # torch.from_numpy cannot consume the ml_dtypes bfloat16 dtype; go through raw bytes.
        return torch.frombuffer(array.tobytes(), dtype=torch.bfloat16).reshape(array.shape)
    return torch.from_numpy(array)
 
 
def _rebuild_from_type_v2(func, new_type, args, state):
    ret = func(*args)
    return ret
 
 
if __name__ == "__main__":
    def recursive_print(state, prefix=None):
        """Print one line per leaf of a nested dict/list checkpoint, labelled with its dotted path."""
        prefix = [] if prefix is None else prefix
        if isinstance(state, dict):
            for k, v in state.items():
                # Fresh path per child; appending to a shared list leaked keys into siblings.
                recursive_print(v, prefix + [str(k)])
        elif isinstance(state, list):
            for i, member in enumerate(state):
                recursive_print(member, prefix + [str(i)])
        elif isinstance(state, torch.Tensor):
            state_name = ".".join(prefix)
            print(f"{state_name} {state.dtype} {state.size()} {state.sum()}", flush=True)
        else:
            state_name = ".".join(prefix)
            print(f"{state_name} {type(state)} {state}", flush=True)

    if len(sys.argv) != 2:
        sys.exit(f"usage: {sys.argv[0]} CHECKPOINT_PATH")
    state_dict = load_ms_weights(sys.argv[1])
    recursive_print(state_dict, [])
